
2D Conv transpose support #311

Merged
merged 21 commits into from Feb 6, 2019

Conversation

tejank10
Contributor

tejank10 commented Jul 9, 2018

This PR adds support for 2D conv transpose to Flux, along with #54 in NNlib.jl
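
To give a feel for the API (an illustrative sketch, not from the PR itself; sizes are made up, and a 2×2 kernel with stride 2 should double the spatial dims):

julia> using Flux

julia> layer = ConvTranspose((2, 2), 3=>16, relu, stride=2)

julia> size(layer(rand(Float32, 4, 4, 3, 10)))
(8, 8, 16, 10)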

@tejank10 changed the title from "2D Conv transpose support" to "[WIP] 2D Conv transpose support" on Jul 9, 2018
@tejank10
Contributor Author

tejank10 commented Jul 13, 2018

The gradtests are failing for ∇conv_data. The gradtest w.r.t. the data passes; the culprit is the gradient w.r.t. the weight, which differs drastically.
EDIT: results of the gradtests:
gradtest(∇conv_data, rand(10, 3, 2), randn(2, 2, 3))

sum.(ngradient(f, xs...)) = (-15.724534273147583, 45.65564203262329)
sum.(data.(gradient(f, xs...))) = (-15.72453422801545, -41.27033625893009)

gradtest(∇conv_data, rand(10, 10, 3, 2), randn(2, 2, 2, 3))

sum.(ngradient(f, xs...)) = (-728.0231256484985, -326.6875114440918)
sum.(data.(gradient(f, xs...))) = (-728.0231184076845, 926.8081786572867)

gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 2, 3))

sum.(ngradient(f, xs...)) = (2724.2830810546875, 3647.9795532226562)
sum.(data.(gradient(f, xs...))) = (2724.2841834190576, -3632.449182033016)
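
For reference, gradtest here is the helper in test/tracker.jl; roughly, it compares a finite-difference gradient against the tracked one (a lightly commented sketch of that helper):

using Flux.Tracker: gradient, data

# Numerical gradient via central finite differences, one entry at a time.
function ngradient(f, xs::AbstractArray...)
  grads = zero.(xs)
  for (x, Δ) in zip(xs, grads), i in 1:length(x)
    δ = sqrt(eps())
    tmp = x[i]
    x[i] = tmp - δ/2; y1 = f(xs...)
    x[i] = tmp + δ/2; y2 = f(xs...)
    x[i] = tmp
    Δ[i] = (y2 - y1)/δ
  end
  return grads
end

# Compare the numerical gradient against the tracked (AD) gradient.
gradcheck(f, xs...) =
  all(isapprox.(ngradient(f, xs...), data.(gradient(f, xs...)),
                rtol = 1e-5, atol = 1e-5))

# Reduce the output to a scalar so the gradient is well-defined.
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)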

@MikeInnes force-pushed the master branch 2 times, most recently from 65799d1 to 193c4de on September 5, 2018
@tejank10
Contributor Author

tejank10 commented Sep 8, 2018

The NNlib counterpart of this PR needs to be merged for the checks to pass.

@tejank10 changed the title from "[WIP] 2D Conv transpose support" to "2D Conv transpose support" on Sep 11, 2018
@Aniket1998
Contributor

What's the status of this PR? I'm working on an InfoGAN model for the Flux model zoo and can't write a version that works on SVHN without conv transpose.

@yuehhua
Member

yuehhua commented Oct 19, 2018

Is there anything I can help with? I'm waiting for this.

@MikeInnes
Member

@tejank10 this PR looks generally good to me, mind just updating it?

test/tracker.jl Outdated
@@ -1,7 +1,11 @@
using Flux
using Flux.Tracker, Test, NNlib
<<<<<<< HEAD

Merge artifact

stride = 1, pad = 0, dilation = 1) where {T,N} =
ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,

julia> ConvTranspose((2, 2), 3=>3)
ERROR: UndefVarError: initn not defined
Stacktrace:
 [1] ConvTranspose(::Tuple{Int64,Int64}, ::Pair{Int64,Int64}, ::Function) at /home/vchuravy/.julia/packages/Flux/hguaX/src/layers/conv.jl:84 (repeats 2 times)


Also:

julia> ConvTranspose((2, 2), 3=>64; init=Flux.glorot_uniform)(rand(4, 4, 3, 10))
ERROR: MethodError: no method matching ∇conv_data(::Array{Float64,4}, ::Array{Float32,4}; stride=(1, 1), pad=(0, 0), dilation=(1, 1))
Closest candidates are:
  ∇conv_data(::AbstractArray, ::TrackedArray; kw...) at /home/vchuravy/.julia/packages/Flux/hguaX/src/tracker/lib/array.jl:390
  ∇conv_data(::TrackedArray, ::AbstractArray; kw...) at /home/vchuravy/.julia/packages/Flux/hguaX/src/tracker/lib/array.jl:391
  ∇conv_data(::A<:AbstractArray, ::A<:AbstractArray; size, pad, stride, dilation, flipkernel) where A<:AbstractArray at /home/vchuravy/.julia/packages/NNlib/nf8OC/src/conv.jl:74
  ...
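
FWIW, passing a Float32 input sidesteps this particular MethodError, which suggests the mismatch between the Float64 input and the Float32 weights is what breaks dispatch (a caller-side workaround only, not a fix):

julia> ConvTranspose((2, 2), 3=>64; init=Flux.glorot_uniform)(rand(Float32, 4, 4, 3, 10))  # eltypes now match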

Contributor Author

Thanks for pointing these out, I've fixed them in the latest commit.

@vchuravy

vchuravy commented Dec 8, 2018

Thanks! I tried to use it again and all seems to work fine on the CPU, but in the GPU case I get:

julia> layer = ConvTranspose((2, 2), 1=>64, relu, stride=2) |> gpu
ConvTranspose((2, 2)

julia> layer(train[1])
ERROR: conversion to pointer not defined for CuArray{Float32,4}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] unsafe_convert(::Type{Ptr{Float32}}, ::CuArray{Float32,4}) at ./pointer.jl:67
 [3] pointer at ./abstractarray.jl:861 [inlined]
 [4] #conv2d_grad_x!#44(::Tuple{Int64,Int64}, ::Tuple{Int64,Int64}, ::Tuple{Int64,Int64}, ::Int64, ::Int64, ::Function, ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::CuArray{Float32,4}) at /home/vchuravy/.julia/packages/NNlib/3j02R/src/impl/conv.jl:

@tejank10
Contributor Author

tejank10 commented Dec 9, 2018

I just made a PR in CuArrays (#223) to reflect the conv changes there. With that, plus the latest conv_transpose branch of NNlib, this should be working.

@NTimmons

This looks like it never got merged. Is there no ConvTranspose in Flux?

@tejank10
Contributor Author

@MikeInnes if #54 is fine, then let's get it and this PR merged?

@@ -77,6 +125,7 @@ struct DepthwiseConv{N,F,A,V}
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
Contributor

This seems like perhaps an unrelated change?

@staticfloat
Contributor

In my limited autoencoder testing, this seems to be working, but I'm worried about the failing gradient tests. @tejank10 do you know why they're failing?

@tejank10
Contributor Author

tejank10 commented Feb 3, 2019

Trying the tests with the convAPI branch of CuArrays in this PR should get them passing.
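
i.e. something like (assuming the convAPI branch lives on the main CuArrays repo; adjust if it's on a fork):

pkg> add CuArrays#convAPI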

@MikeInnes
Member

I just merged that PR. Worth testing this again, figuring out if we need a CuArrays tag etc.

@staticfloat closed this Feb 5, 2019
@staticfloat reopened this Feb 5, 2019
@tejank10
Contributor Author

tejank10 commented Feb 5, 2019

The gradtests are failing because the build is using NNlib v0.4.3. Trying the tests with NNlib master seems to work.

@MikeInnes
Member

You should be able to add NNlib master to the manifest and get it tested that way.
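
Something like this from the Flux project environment should pin it in the Manifest:

(Flux) pkg> add NNlib#master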

@MikeInnes merged commit e8b2ec6 into FluxML:master on Feb 6, 2019
@MikeInnes
Member

Awesome stuff @tejank10, thanks!

@sidhantnagpal

Thanks for working on this. It will be useful in implementing GANs using Flux.

@MikeInnes
Member

Also thanks a lot to @staticfloat for the review!

@AzamatB
Contributor

AzamatB commented Feb 6, 2019

Does this support only 2D images, as the name suggests, or 1D and 3D as well?

@tejank10
Contributor Author

tejank10 commented Feb 6, 2019

It should support 1D and 3D as well.
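
The constructor takes an NTuple kernel, so dimensionality follows the tuple length (illustrative sizes):

julia> ConvTranspose((4,), 3=>8)       # 1D: length-1 kernel tuple

julia> ConvTranspose((4, 4, 4), 3=>8)  # 3D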
